from dataclasses import dataclass, field
from itertools import product
from pathlib import Path
from typing import Dict, List, Optional


@dataclass
class NeedleConfig:
    """Describe a grid of "needle" variations and expand it into config dicts.

    Every combination of ``keywords`` x ``sizes`` x ``posns`` becomes one
    config dict; see :meth:`get_configs`.

    Attributes:
        keywords: Needle keywords. ``None`` means "no corpus"; any keyword
            containing ``"lorem"`` loads ``_datasets/needles/lorem.txt``.
        sizes: Needle lengths; :meth:`get_configs` requires each to lie in
            [-1, 3] (the validation message calls them "pcts").
        posns: Needle positions, passed through unchanged as ``"posn"``.
        mode: Either ``"insert"`` or ``"remove"``.
        corpus: Corpus text for each entry of ``keywords`` (same order),
            filled in by ``__post_init__``; not an ``__init__`` argument.
    """

    keywords: List[Optional[str]]
    sizes: List[float]
    posns: List[float]
    mode: str = "insert"
    # Derived from ``keywords``. Kept out of repr because it can hold the full
    # text of a corpus file.
    corpus: List[Optional[str]] = field(init=False, repr=False)

    def __post_init__(self):
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if self.mode not in ("insert", "remove"):
            raise ValueError("Mode must be 'insert' or 'remove'.")

        # Loading eagerly surfaces unrecognized keywords at construction time.
        self.corpus = [self.get_corpus(keyword) for keyword in self.keywords]

    @staticmethod
    def get_corpus(keyword: Optional[str]) -> Optional[str]:
        """Return the corpus text associated with ``keyword``.

        Args:
            keyword: ``None``, or a keyword containing ``"lorem"``.

        Returns:
            ``None`` when ``keyword`` is ``None``; otherwise the UTF-8 contents
            of ``_datasets/needles/lorem.txt`` (a relative path, so resolved
            against the current working directory).

        Raises:
            ValueError: If the keyword is not recognized.
        """
        if keyword is None:
            return None
        if "lorem" in keyword:
            return Path("_datasets/needles/lorem.txt").read_text("utf-8")
        raise ValueError("needle keyword not recognized")

    def get_configs(self) -> List[Dict]:
        """Generate all combinations of needle, size, and position.

        Returns:
            One dict per ``(keyword, size, posn)`` in ``itertools.product``
            order (position varies fastest), shaped as
            ``{"name": str, "mode": str,
            "params": {"corpus": ..., "length": size, "posn": posn}}``.

        Raises:
            ValueError: If any size lies outside [-1, 3], or a keyword is not
                recognized by :meth:`get_corpus`.
        """
        # Validate each size once instead of once per combination.
        for size in self.sizes:
            if not -1 <= size <= 3:
                raise ValueError("Needle sizes are pcts (-1 to 3)")

        # Load each distinct keyword's corpus once per call, not once per
        # (size, posn) pair. Reads ``self.keywords`` (not ``self.corpus``) so
        # the result stays correct if ``keywords`` was changed after init.
        corpora = {
            keyword: self.get_corpus(keyword)
            for keyword in dict.fromkeys(self.keywords)
        }

        configs = []
        for keyword, size, posn in product(self.keywords, self.sizes, self.posns):
            if keyword is None:
                name = f"needle_{self.mode}_{posn}pos_{size}sz"
            else:
                name = f"needle_{self.mode}_{keyword}_{posn}pos_{size}sz"

            configs.append(
                {
                    "name": name,
                    "mode": self.mode,
                    "params": {
                        "corpus": corpora[keyword],
                        "length": size,
                        "posn": posn,
                    },
                }
            )
        return configs
